{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-26T01:58:12.294665Z",
     "start_time": "2019-05-26T01:58:10.019832Z"
    },
    "code_folding": [],
    "run_control": {
     "marked": false
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "from Util import my_helper\n",
    "import copy\n",
    "import itertools\n",
    "import random\n",
    "import pickle as cPickle\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.layers import core as core_layers\n",
    "\n",
    "\n",
    "from Model import LM\n",
    "from Util.myAttWrapper import SelfAttWrapper\n",
    "from Util import myResidualCell\n",
    "from Util.bleu import BLEU\n",
    "from Util.myUtil import *\n",
    "\n",
    "\n",
    "# Emit INFO-level (and above) TensorFlow log messages.\n",
    "tf.logging.set_verbosity(tf.logging.INFO)\n",
    "\n",
    "# Session configuration shared by every tf.Session created in this notebook;\n",
    "# it enables the GPU `allow_growth` option.\n",
    "sess_conf = tf.ConfigProto(\n",
    "    gpu_options=tf.GPUOptions(allow_growth=True),\n",
    ")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-26T01:58:12.309216Z",
     "start_time": "2019-05-26T01:58:12.297026Z"
    }
   },
   "outputs": [],
   "source": [
    "def idx2str(idx):\n",
    "    \"\"\"Look up every element of `idx` in the global `id2w` and join the results with spaces.\"\"\"\n",
    "    words = [id2w[token_id] for token_id in idx]\n",
    "    return ' '.join(words)\n",
    "\n",
    "def str2idx(idx):\n",
    "    \"\"\"Look up every element of `idx` in the global `w2id` and return the results as a list.\n",
    "\n",
    "    Despite the parameter name, the elements of `idx` are used as keys of `w2id`.\n",
    "    \"\"\"\n",
    "    ids = []\n",
    "    for token in idx:\n",
    "        ids.append(w2id[token])\n",
    "    return ids\n",
    "\n",
    "def _init_data(name):\n",
    "    \"\"\"Load the pickled vocabulary pair and index data stored under `Data/<name>/`.\n",
    "\n",
    "    Args:\n",
    "        name: Sub-directory of `Data/` to read from (paths are relative to the\n",
    "            current working directory).\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(X_indices, w2id, id2w)`: the object unpickled from `index.pkl`\n",
    "        followed by the two objects unpickled from `w2id_id2w.pkl`.\n",
    "\n",
    "    Note:\n",
    "        Unpickling can execute arbitrary code; only load files you trust.\n",
    "    \"\"\"\n",
    "    # Context managers guarantee both files are closed, even if unpickling fails.\n",
    "    with open('Data/%s/w2id_id2w.pkl' % name, 'rb') as vocab_file:\n",
    "        w2id, id2w = cPickle.load(vocab_file)\n",
    "    with open('Data/%s/index.pkl' % name, 'rb') as index_file:\n",
    "        X_indices = cPickle.load(index_file)\n",
    "    return X_indices, w2id, id2w\n",
    "\n",
    "def _train_model(name, lr=10.0, l1_reg_lambda=0.00, l2_reg_lambda=0.00, close_loss_rate=0.00):\n",
    "    \"\"\"Build an `LM` model for the task `name`, train it, and return it.\n",
    "\n",
    "    Reads the notebook-level globals `X_indices`, `w2id` and `sess_conf`, so the\n",
    "    cell that defines them must have been executed before calling this function.\n",
    "\n",
    "    Args:\n",
    "        name: Task name; must be 'Poem', 'Daily' or 'APRC'.\n",
    "        lr: Forwarded to `LM` as `lr`.\n",
    "        l1_reg_lambda: Forwarded to `LM` as `l1_reg_lambda`.\n",
    "        l2_reg_lambda: Forwarded to `LM` as `l2_reg_lambda`.\n",
    "        close_loss_rate: Forwarded to `LM` as `close_loss_rate`.\n",
    "\n",
    "    Returns:\n",
    "        The `LM` instance after `LM_util.fit` has run. The `tf.Session` handed to\n",
    "        the model is intentionally left open because the returned model holds it.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `name` is not one of the supported tasks.\n",
    "\n",
    "    NOTE(review): an unused per-task `model_iter` table (Poem 30, Daily 30, APRC 20)\n",
    "    used to live here; training always runs for NUM_EPOCH epochs. Confirm whether\n",
    "    the epoch count was meant to be task-specific.\n",
    "    \"\"\"\n",
    "    # Per-task hyper-parameters: (rnn_size, n_layers, max_infer_length).\n",
    "    task_config = {\n",
    "        'Poem': (512, 2, 33),\n",
    "        'Daily': (512, 1, 50),\n",
    "        'APRC': (1024, 1, 36),\n",
    "    }\n",
    "    # Validate first so an unsupported name fails here rather than as an\n",
    "    # unrelated file-not-found error further down.\n",
    "    assert name in task_config\n",
    "    rnn_size, n_layers, max_infer_length = task_config[name]\n",
    "\n",
    "    # The context manager closes the file even if unpickling fails.\n",
    "    with open('Data/%s/qid_list.pkl' % name, 'rb') as qid_file:\n",
    "        qid_list = cPickle.load(qid_file)\n",
    "    # Replace each entry with the value stored for it in the global `w2id`.\n",
    "    qid_list = [w2id[w] for w in qid_list]\n",
    "\n",
    "    BATCH_SIZE = 256\n",
    "    NUM_EPOCH = 30\n",
    "    train_dir = 'Model/%s' % name\n",
    "    dp = LM_DP(X_indices, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)\n",
    "\n",
    "    # Give the model its own graph and session.\n",
    "    g = tf.Graph()\n",
    "    sess = tf.Session(graph=g, config=sess_conf)\n",
    "    with sess.as_default():\n",
    "        with sess.graph.as_default():\n",
    "            model = LM(\n",
    "                dp=dp,\n",
    "                rnn_size=rnn_size,\n",
    "                n_layers=n_layers,\n",
    "                decoder_embedding_dim=rnn_size,\n",
    "                cell_type='lstm',\n",
    "                close_loss_rate=close_loss_rate,\n",
    "                max_infer_length=max_infer_length,\n",
    "                att_type='B',\n",
    "                qid_list=qid_list,\n",
    "                lr=lr,\n",
    "                l1_reg_lambda=l1_reg_lambda,\n",
    "                l2_reg_lambda=l2_reg_lambda,\n",
    "                is_save=True,\n",
    "                residual=True,\n",
    "                is_jieba=False,\n",
    "                sess=sess\n",
    "            )\n",
    "\n",
    "    util = LM_util(dp=dp, model=model)\n",
    "    util.fit(train_dir=train_dir, is_bleu=False)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2019-05-26T01:58:09.594Z"
    }
   },
   "outputs": [],
   "source": [
    "# Task to train on; the alternatives are 'Daily' and 'Poem'.\n",
    "task_name = 'APRC'\n",
    "assert task_name in ('Poem', 'Daily', 'APRC')\n",
    "\n",
    "# These globals are read by idx2str, str2idx and _train_model.\n",
    "X_indices, w2id, id2w = _init_data(name=task_name)\n",
    "model = _train_model(name=task_name, lr=10.0)\n",
    "\n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
